#!/usr/bin/env python
# -*- coding: utf-8; py-indent-offset:4 -*-
###############################################################################
#
# Copyright (C) 2015-2020 Daniel Rodriguez
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
###############################################################################
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

from collections import OrderedDict
import datetime
import backtrader as bt
from backtrader.utils.datehelper import *
from backtrader.utils.py3 import range
from backtrader import Analyzer

"""
分阶段统计收益
"""


class BenchmarkReturn(Analyzer):
    '''
    Equal-weighted benchmark return analyzer.

    When the run ends (``stop``), it computes simple returns for every data
    feed in ``self.datas``. Each feed's returns are multiplied by
    ``1 / len(self.datas)`` and summed per period key. The periods are:

      - calendar months, keyed by ``year * 100 + month`` (e.g. ``202003``)
      - calendar years, keyed by ``year`` (e.g. ``2020``)
      - the recent ranges ``"1month"``, ``"3month"``, ``"6month"`` and
        ``"ytd"``, keyed by the second value ``recent_date`` returns for
        each label

    The first calendar period of a feed is measured from the open of its
    first bar. Every later period is measured from the last close of the
    previous period to the last close of the current one.

    NOTE(review): a feed that has no bars in a given period adds nothing to
    that period's key, yet the weight is still ``1 / len(self.datas)``. Such
    periods are therefore effectively averaged against a zero return.

    Params:

      - (None)

    Member Attributes:

      - ``ret``: ``OrderedDict`` mapping period key to weighted return

    **get_analysis**:

      - Returns ``ret`` (an empty ``OrderedDict`` if ``stop`` has not run)
    '''

    def curr_dt_key(self, dt, dt_mode):
        '''Return the grouping key of date ``dt`` for period ``dt_mode``.

        ``bt.TimeFrame.Months`` gives ``year * 100 + month``.
        ``bt.TimeFrame.Years`` gives ``year``.
        Any other mode gives ``-1``, so every bar lands in a single bucket.
        '''
        if dt_mode == bt.TimeFrame.Months:
            return dt.year * 100 + dt.month
        if dt_mode == bt.TimeFrame.Years:
            return dt.year
        return -1

    def cal_ret(self, data, dt_mode=bt.TimeFrame.Months):
        '''Return an ``OrderedDict`` of per-period simple returns of ``data``.

        Bars are walked from oldest to newest and grouped with
        ``curr_dt_key``. Keys appear in chronological order. The first
        period's base price is the open of the oldest bar. Each later
        period's base is the last close of the period before it. The final
        (current) period ends at ``data.close[0]``.

        An empty feed yields an empty ``OrderedDict``.
        '''
        returns = OrderedDict()
        size = len(data)
        if size == 0:
            # Nothing to measure; avoids dividing by an undefined base price.
            return returns

        # Line indexes are relative: [0] is the newest bar and [-k] is k bars
        # back, so -(size - 1) addresses the oldest available bar.
        oldest = -(size - 1)
        pre_dt = self.curr_dt_key(data.datetime.date(oldest), dt_mode=dt_mode)
        pre_close = data.open[oldest]  # first period is measured from the open

        for idx in range(size - 2, -1, -1):
            curr_dt = self.curr_dt_key(data.datetime.date(-idx),
                                       dt_mode=dt_mode)
            if curr_dt != pre_dt:
                # Period boundary: close the previous period with the last
                # close before the switch, which also becomes the new base.
                last_close = data.close[-idx - 1]
                returns[pre_dt] = last_close / pre_close - 1
                pre_close = last_close
                pre_dt = curr_dt

        returns[pre_dt] = data.close[0] / pre_close - 1
        return returns

    def cal_range_ret(self, data, filter_dt, key):
        '''Return ``{key: return}`` from the first bar on/after ``filter_dt``.

        The base price is the open of the oldest bar whose date is not
        before ``filter_dt``. The end price is ``data.close[0]``. If no bar
        qualifies, or the feed is empty, the return is NaN.
        '''
        start_price = float('NaN')
        size = len(data)
        for idx in range(size - 1, -1, -1):
            if data.datetime.date(-idx) < filter_dt:
                continue
            start_price = data.open[-idx]
            break

        end_price = data.close[0] if size else float('NaN')
        returns = OrderedDict()
        returns[key] = end_price / start_price - 1
        return returns

    @staticmethod
    def _accumulate(total, stats, weight):
        '''Add ``value * weight`` for each ``key: value`` of ``stats`` to
        ``total``. New keys are appended in ``stats`` order.'''
        for key, value in stats.items():
            if key in total:
                total[key] += value * weight
            else:
                total[key] = value * weight

    def stop(self):
        '''Compute the equal-weighted returns of all feeds into ``self.ret``.

        The recent-range windows are anchored on the last date of
        ``self.data`` (the first feed).
        NOTE(review): if feeds end on different dates, every feed still uses
        this one anchor — confirm that is intended.
        '''
        self.ret = OrderedDict()
        if not self.datas or len(self.data) == 0:
            # No feeds, or the reference feed has no bars: there is no end
            # date to anchor the recent ranges on.
            return

        weight = 1 / len(self.datas)
        end_time = self.data.datetime.date(0)

        # The range start dates depend only on end_time, so compute them once.
        ranges = [recent_date(end_time, label)
                  for label in ["1month", "3month", "6month", "ytd"]]

        for d in self.datas:
            if len(d) == 0:
                # A feed without bars has nothing to contribute.
                continue

            for start_date, range_key in ranges:
                self._accumulate(self.ret,
                                 self.cal_range_ret(d, start_date, range_key),
                                 weight)

            self._accumulate(self.ret,
                             self.cal_ret(d, dt_mode=bt.TimeFrame.Months),
                             weight)
            self._accumulate(self.ret,
                             self.cal_ret(d, dt_mode=bt.TimeFrame.Years),
                             weight)

    def get_analysis(self):
        '''Return the weighted returns computed by ``stop``.

        Returns an empty ``OrderedDict`` if ``stop`` has not run yet.
        '''
        return getattr(self, 'ret', OrderedDict())
